{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import math\n",
    "import time\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.utils.data as Data\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.core.interactiveshell import InteractiveShell \n",
    "InteractiveShell.ast_node_interactivity = 'all'\n",
    "from pydataset import data  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们使⽤线性模型参数w = [2, −3.4]⊤、b = 4.2 和噪声项ϵ⽣成数据集及其标签：\n",
    "y = Xw + b + ϵ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "features: tensor([-0.2054,  0.1071]) \n",
      "label: tensor([3.4058])\n"
     ]
    }
   ],
   "source": [
    "def synthetic_data(w, b, num_examples): #@save\n",
    "    \"\"\"⽣成y=Xw+b+噪声\"\"\"\n",
    "    X = torch.normal(0, 1, (num_examples, len(w)))\n",
    "    y = torch.matmul(X, w) + b\n",
    "    y += torch.normal(0, 0.01, y.shape)\n",
    "    return X, y.reshape((-1, 1))\n",
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "features, labels = synthetic_data(true_w, true_b, 1000)\n",
    "print('features:', features[0],'\\nlabel:', labels[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD4CAYAAADxeG0DAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO2dfXRU533nvw8v4yKJGEkMKm9GQpLlKClRwhi7GOzwltoph7Q9xanbs6XeniXdc0pYN3tO6tRn3XbddbddhyX+o4Fu4tKzrRv7tD3hsHHXgAkgqMHCJdRRENIg8b54NBIEaagHyc/+cee5eubOvXfuzNx5uaPv5xzOSKP78twR+t7f/b0KKSUIIYQElxnlXgAhhJDCoJATQkjAoZATQkjAoZATQkjAoZATQkjAmVWOk86fP182NzeX49SEEBJYzpw5MyylDFvfL4uQNzc3o6enpxynJoSQwCKEuGT3Pl0rhBAScCjkhBAScCjkhBAScCjkhBAScCjkhBAScCjkhBAScCjkhBAScCjkFkbGk9hzNIqR8WS5l0IIIZ6gkFt4s+cKXn7rPN7suVLupRBCiCfKUtlZyWyNLE17JYSQSodCbqGhNoSvPNFa7mUQQohn6FohhJCAQyEPAAzAEkLcoJAHAAZgCSFu0EceABiAJYS44dkiF0J8VwjxoRDiA+29PxRCXBNCnE39+2Jxljm9UQHYhtpQuZdCCKlAcnGt/BWAJ23e3yWl7Er9+4E/yyKEEOIVz0IupTwGYKSIayGEEJIHfgQ7f1cIcS7leql32kgIsV0I0SOE6InFYj6clhBCCFC4kP8FgFYAXQBuAHjFaUMp5V4pZURKGQmHM2aHEkIIyZOChFxKeVNKOSml/BjAXwJY5c+yCCGEeKUgIRdCLNS+/WUAHzhtSwghpDjkkn74OoB/BtAhhLgqhPhtAH8mhPhXIcQ5AOsAPFekdRZEKSsjq/VchJDKxXNBkJTyGZu3v+PjWoqGqowEULSGWCPjSbzZcwWJ5AR2Hx4o6rkU6roSyUnUhGZia2Qpc80JmYZMi8rOUlRGKlHduaEdzz/1UEmqMNU5EsmJot+oCCGVS1UJubKKrZZpKVrT6jeLUlnF6rpGxpOoCc1iCT8h05SqEvJSuFCcKGcfc/ZQJ2R6U1VCzuZShJDpSFW1sWVzKWayEDIdqSohJ+xdTsh0JFBCXoi16bRvoRZspVnAWyNLzayZSlsbIaQ4BErIC7E2nfYt1IItlgWcrwjr7iVa54RMDwIV7CwkmOm0b6EB0lz3d0qRtOJHBg6Dv4RMD4SUsuQnjUQisqenp+TnrQT2HI3i5bfO4/mnHnIVaK+CTwiZPgghzkgpI9b3A+VaCQpubhHdh+22f7lEnH51QoIHhbwIuPmmvaRIltO3Tb86IcEjUD5yPymm1Vtqv7uflOrcdB0R4h/TVsiLWc5faMl8tZf7j4wn8bU3zuJInzHyj+0FCCmMaSvkzOgoH2/2XMGRvhjWdYT5+RPiA9NWyHXLk4/5paUcnSIJqWYY7ETpA3ylzgzRz1cJWSnsiUOIv0xbi1yn1G4WN/+816cDu+2c9tXPB4BDKAipMijkKH1w0e3GoYvu1shSR7GOjyWx9/hFJJITeG5TR8a++vXYnY++aUKqBwp5GXC7cWyNLEUiOYlEcgL7Tg5h9+F+AFPCrMR6TVtjag+Rtq/+6nS+XG5aViuf8QRCKg/6yEuEV990Q20INaGZqQHOMqMKVFWG/tGXPo3nn3oIW7oWmccthu/ZGj/Qv68EfzshhBZ5ycjmMtFxy+rQrevWJ+rM3i2Ad0s7F6vaauXrr+UcrUcImYJCXiJyEUA314tVhL0Eaq37eBVgO8HX18ZcfEIqAwp5ifBLAK0i7EX0E8mJlKvG2Mfr+Qu54bhRSGYOISQTCnkZKCRLRhdhN6HTy+B3bmhP87VnO7867sbOprRz+oXXJwK6bgjxBoW8DBRiaeoi7OYf18vgt61uzuk8xRZQr08EdN0Q4g0KeRmwE8p8xN1N6Aopg/dbQK3X5vWJpJzNwwgJEkw/9IDfaXZ2wyXyaRPglm6oAqEqTTAX/E5jtF4b0xYJ8Rda5B7w29VgZ2kWw41QKT5m67VVyroIqRYo5B4oha+2GG6ESvExW6+tUtZFSLVA14oHCnU1FNuV4HT8XNZdSneH27rodiEkdyjkJaDYbXL9OH6lzOqslHUQEiQ8u1aEEN8FsBnAh1LKT6feawDwPQDNAIYAPC2lHPV/mcFlZDyJRHICOze0F82V4OSqyKcUf2NnE/Ycjbruk0sLXb+uhRDiTC4W+V8BeNLy3u8DOCylbAdwOPU90Xiz5wp2Hx5ATWhm0aoTnVwVuVi36hiHem9m3cfuuOq9r71xtiC3CIdOEJI7ni1yKeUxIUSz5e0vAfh86ut9AH4I4Os+rKtqKKaFmc0Ktp7bi9Wstn24uQHPvnYaL2zuRGu4Lm1/u4rPrZGlePdiHEf6Yniz5wqzUQgpIYX6yJuklDcAIPW6oPAlVRfFtDCdLG4VMASQdm4vFrpa76vv9ONIXwwvHejNON9LB3ozbgYNtSG88nRXRn58rjDYSUjulCzYKYTYLoToEUL0xGKxUp22ajF875PYuaHN7LuiBNBJsO0KkZx4YXMn1nWE8cLmzrT913WETatbX4vdjcPrdejCzWAnIblTaB75TSHEQinlDSHEQgAfOm0opdwLYC8ARCIRWeB5y0Y5OvLZndPwvffj+aceQkNtKK3vitdJQU7HBoDWcB1ee3ZVxv6vPN1lbj/VXXEyY5KRV6zFQbkEXQkhBoUK+X4A2wD8aer1+wWvqMIpR1Wi3TndBj7kUlyU7XqchH40kcRLB3pT3RXb8napWNev1p7PwAxCpiu5pB++DiOwOV8IcRXAizAE/A0hxG8DuAxgazEWWUmUIz3O7pxWsc63MlSfEarGxelYhV59rwKbRnfFlrytZqd1l/JzZt9zEnRyyVp5xuFHG3xaSyAoR0e+XCYG5ToseTSRxIFz1xGNjePc1dt45emutO2cLP+NnU14dPnNvIYye9m22J+zvgb2fiFBh71WAo6Txez0vZWXDvQiGhtHQ+1s29RBJagqKLk1sjRtZqjTOuzww6eeC243DOsMVf2VkKBBIQ84br5yu1crRlZKL3asb8d7QyOO2+kuFavV7uU86ROL0n3qflWK5jKbNN+YAiGVCHutBBw9T91tWLKTGKrslM8tq3fdbmNnE1rDtRmph9Z1AMCugxew62BfWi64PrFoS9fitH3dKkWzpSHapV2q6lK3dEs27iLVBC3yKiIX94bu2953chCAcB0Jd6j3JqKxcazrCLu6IPadHDLdJjWhWWlZNonkBACB/Wev2Q6D1lMO8xkQbVddmo+lrY6ZSE6iJjTTN78/IcWCQl4F5DIs2Sr2+04OmqJaE5rpKHwbO5tw7EIM7U1z086ZKVxGicBjrY0ZWTY1oVl4+a3zjsOgrSmHXkTY6iLR89ztPqNcGoglkhOeg6AMmJJyQiEPMFPBw4k0C9eNTEtXAADWtM139W+/dKAXJ6JxnIjG0ZgSQjurddvqltQxM2u+ss0RzSfo6DUNMxeh1QO8NaFZntbDgCkpJxTyAKPEyWrhumEVum2rm1ETmomNnU2uGR5H+mJ4rLURn1p8PxLJCdPPbWe19gyN4EQ0DgB4blNHxrn1DBhrv5ZiWbMbO5vw7sW4+dTihVzWw4ApKScU8gCj+5YP9d70vJ/VzZCtklK3NtXN48ylUaxc1oAtXYtRE5pl+rcTyQlTxJW175RN4uaD9tvnfKj3Jo70xfDo8ptpaZOEVAMU8gCTazm7mytGF2s3EdUDit0DhmDXhGbie+9dwd5jF/FISwO2r23BnNAsbFvdDMC+n0oiOZlmuVvX7bfPuRDXhx+pkIQUEwp5FZBrhoedK0Z3DVhvDCoTJZGcxHObHsQrT3eZmS6AxMtvncdjrY0AgFODI1j/0ALXYctG4HMmTkTjeKy1EYnkZEZ7AL99zoW4PvK5qTD4SUoJhTxAOFl5XkUqW7DRbjsDmfbaUBsyfd8qILixswn7z14DIGwLfazrUz7r9qa52H24PyNjJpdOjXb4aRE73VT0bKFDvTfTzsXgJyklFPIAkW+nQoVXwbdut6VrMc5dvZ1RyGPdVhd35S/ffXgAieSEmf2h1qV81iuW3J/xdGAVSPWaS3aOnxZxtkwY5WrSz8XgJyklFPIAkc3Ky0e8nMTfKBQagrLC3QKFTsHM7Y8vx7qOMO4mP8buw+5tePWiJKtAqlflEtrY2YRdBy8AkNjStTjDGgbyy1LJFbsGYoSUAwp5gMjmbvAasNRxEn81uAIwcszVJCI7VFFRIjmB5zZ1pBXUHOmLYWlDDVrDtXi4ucH2WvYcjZqW9rmrt8ypREogrZ0Wje2NtR04dwPR2HjG+p2yVPJ1ubi1PwDATBhSVijkAccqxE4BSyecrHw9s6R7YBhr2+e7CJ9Ie7UW1BzvjyEaG8er7/TjtWdXZYiiKt8/c2nUFF+V6mjnf97Y2YTj/cNITkzi9NAo1nWEU1Z6H5RV73Rd+bpcGLwklQyFPOC4CbHd+1acfLlGQPPBNNF1YkvXIpy7egtbuhaZ7+n7bexswksHek1LeyqPfMp3/tymDtse4cf7h9E9MIzj/cP41jOfRUNtCId6b6J7YBg7N7RjwyebsDWy1LbVgH5dubQxsIPBS1LJUMirFK/BNqtQu7kP7PZRwmp1Y1gtWDX/c2Q8ifhYEmvaGjE6fi/lkjHSGvVzqfXExz5C94Ah5qoRln32jfE0sLR+DuJjSTOdMVsbA6+uFgYvSSVDIQ842fqEe91f4cV94DQwWbdWlWvGOkLuzZ4r2Hv8IgBAprIae4ZGMvLIdffMnNBM6GmNdqK6bXUzzl29hSN9Mew9fhGNdaG0wRp6oFRvD0CXCakGKOQBxBrgtLZudds+20CIRHIC8fEkdh3sc5zFac0IsRNWVfTz8lvnHdvZbulaZA5wdlq7nrPuhup8+O0fDqD3xp0MF4oeKH35rfOIjyXR/+Ed7FifvU9NMao0WflJ/IRCHkCsVqRT61an7e16rQCGuJy7etvMidYFWP1838lBnLl0C90Dw1n7lijRVpWbai36DUKt3Wop2+ElT76x7j50DwziUO9N1EdCGdurm9CPro7i1OAoAJhuHyeKYbXzSYD4CYW8QMphWeUaeLNu75ZyeKQvhjVt87Fy2TzbjA/lZ1bphLr42vnbAZiVmwAy3EDWfjGJ5GTqbBJPPLgAr77Tjxc2d6I1XJcxRMLuc9dzu9Voufh4Ev03Dev71Xf6caQvhkda6rGmbb4ZgPXz8/YCg6fETyjkBVIOy8rqysi2Buv2XjJdlDDrQj2VJmhY5EoU1Xn1robKX72mrRE7N7Sbx3ZyA+muHXWzmMoR78UrT3chkZzAzg3taSKtzm1tB7DnaNQcLdczNIL3L9/C+5dHcfvuBFrDtTg1OIrnn3oIreHs+d/WpxY/btwMnhI/oZAXSCVYVnZrcBMct5RDtxuE8lfrqXwq53tkPGkKLSBxpC+G1nAtugfiWNse9uxKMdoAGM24nnhwAb558ALaF9SlGncNYOeGNtOv3hquNX3h1sZe+mfyH/93DwCYIv7nv/oZc9B0rsLs5amAkFJDIS+QSrCs7Nbgx5NCtjL3+pqpzBJlIavA4VQjretpmStOrXf19T636UHzHGvb56eyTtqwc0Ob+TTQ3FiDaGwcL37/x/jWM5+FXWMvdd1ff/KT+L03zuITPzML5679FEcvxMxz6C4dt/mcdnnoleDnZtCUABTyQOLlj9ePJwU9P1wFDlVxj9WlotwY1gCqylwxvp6luVAm08r+vbh73uy5gu6BYazrCKN9QR32Hh8088u3rW5JO340NmYWIb03NIKheALNjTUAgH+ODuNuchL/NjGJvv/3U/zmzy9z7Y0OZA5kBuxvdHqPGqesHz+phJsJKT8U8gDi5Y/Xy5NCthuCLqKqcvLYhRhOROOmaOvb2Y2LmyrsSWL34fNm6uHuw/3YuaE9p8Ij/Tz7z143B1hYbx4AzJvNvckPsHJZAx5rbcSJaBz1NbNxemgUp4dGzW2Hx5KIxsbNUn87l4/uw1cB2xVL7rcthFJ9YM5dvY0XNneaLQbUz/20nivBtUfKD4U8gOTTHMuO3IKkRuXkpxbfj8cfnLK8dZeD1VLXj2F0K1THSXeD2DEynsRXX/8XdA8MZ1R+qqZZyo1jd/1GNkovltbXYPfhfmxfuxyhWTPM9T3SUo+On/0ELsbG8HubOkyfudNnohcoqRTNFUvmZRQZ6T1qjHMZn4keAHb6vPOhElx7pPxQyCuAXMXYbZpPLuRizakhzfoadd+4ykbRLXW7/dUAip0b2rFtdbNL7/FJdA8Mp/ZOF3yru8Xu+lvDdXjt2VXmDaT3xm383qYO3JuU6Fw4F7/z+baMlEnd6nf6TFThkb6P9Xeg96hRAWHVCdLp8yGkECjkFUAhfs5CHq1zseacAqpKnF7Y3JnWatZp/10H+7D78AC2r21J9UCZxO7D/Ta9x9vMDJhtq1sc15ItY0cv3e+98R5Gxu9hbft8AEgNv5g0XSG5fBb6k4DdGqwtblUnSAYlSTGgkFcApRLjXFGVnKo1rFt5f0NtyGNPbsNF03vjDroHBrFzQ5vpnrDrPZ5tfXZPMtYUwRVL7sfg8DiG4gk01M42ffkqG0Z3j+h57G7phXbtg6159zq6a0Zto45DcSeFQiGvACrVz6lXclpnagKZhTJesjV0F4u1z7i6ETjdEJwmEenVoNtWt2S4XtQTwP/98U1cGklg/9lrppVvdY/ozbWsRUc6Dzc3mNWt1tYFdtvrn2muTcrsrp0QHQo5cURVcuqdB53QszVUjxZryb7Vus91qo7V0lYpjIDUbjiz0nqRqxTBL696AHNCs1JrFBk3T7Xdlq5FaA3XpVWG2rlt9GEZjy5vNM+vtncSXrunLy+FXEwzJG5QyKcJ+Vh0XjsPAlPCCkjbni4ATLED0ot+vGKXDrlzQ1vKuhZp51aoXHigFzvWt2cMwACMz0Zl3KhUQqvbSGHOI127HADQvmAuNnY2IT72EXpv3MELmzsxmkhi+1/3ZIygs/sdeC3kYpohccMXIRdCDAG4A2ASwISUMuLHcYl/FNuiUxOFFEbJvmExq0yVR1oacGpwBNYsFGvmSraAqcHUeDl9mpHVvbM1MtXm996kRPfAMFYsuZ621n0nh8xmYdmE0i5jprEuveviuxfjiMbG0RquTTteLk2/rOuoVPcbqQz8tMjXSSmHs29GykGxLTo7/7XK9T7Ue9O0ntc/tCBjDUrgDv/kJk4PjSI+9hG+8YvuXQn1dEj9OMq9AwDPbepISxdU04buJicsQUnjxrJy2by0ARhu+eTA1OzQ+NhH+PKqBwAYvvN3zn+IR1rq8d9+ZUWGSKthG6o3jNvx9epUL829yPSFrpVpQr4WnVeXjNepQfoxrP1L3jn/IQAjo8ULieQEvn00ijmzZ2Db6hZs7GzC66cvYyiewJTFnjlt6N2LIzg1OGIWGlnL+9XTxPa1LbZDNvTxcWoMXWPdffjKE6149rXTODU4gnUdYdTXhNJuGPqwDZUt43ZjnSqw6k0blVfMoCeDqsHELyGXAN4WQkgAe6SUe60bCCG2A9gOAA888IBPpyV+YvdH7NUlYxVu643DrvzemhWiqkN3rG/HnqNRV1eLPmwZMMrhVyyZh6F4Aus6wtjStcg2FbBnaDTl3gHsGmypa959uB+t4VrTz60P2VCfyZq2+di+djnmhGaY160qSl/Y3Gn72Tn53u3Ysb4dl0cS2LG+PW1txXSRMagaTPwS8seklNeFEAsAHBRCnJdSHtM3SIn7XgCIRCLOtdmkbGQTnlxb47ptr4qJ9Fa0reE6vPJ0F776+vvoHojjeP+wSzqfYXE/0tIAADjSF0P7grlpZfvWa9l3ctBsjLWmrTGj0EixNbIUx/tj6B6IY2n9HDz16YUZTxbHLsTQPTCMlcvmYdvqFvM662tCeHR5I+prQqYrJT6WbtV7Fcj3hkYQjY3jvaERNM+vzei+WAwYVA0mvgi5lPJ66vVDIcQ/AlgF4Jj7XqTSyFahmGs7ADfrzhBLQ6j3n71uBiu/9sZZdA8YYtu5cC7WttsHIHUf+b6Tgzg1OII5oRmuFZ9K/Ne0zce3nvmso0XcUBvCymUN6B6I48roXfR/eCfj55HmhtRNQdjmhh/vH0bnwrmpwifjZmSXlmnnarJrF2D3WebqBvGyPYOqwaRgIRdC1AKYIaW8k/r6CwD+uOCVkZKT7Y/YKo65dE+0O9fKZfNSImc8oE2NmmvEymUNttWkdmzpWoxzV29jS9fiDL+7zrbVzamv3B8IjfmiEtsfX44fXRnFkb4Y9p0cSst0sQu2qq9VlowS8EdaGjBrhjDXpET5eP+URa8XOOmj8Kz58Pp1Od0omYs+/fDDIm8C8I9CCHW8v5VS/pMPxyUVRrYJQkCmiLgJhjXIaPUfu5W8W61glQMOGBbxGz1XEI2N492LcbywuRPfO30ZvTfuoHPhXOw9PohzV2+bYmlFVYM+/9RDeHR5Y2pI85T4Z8sHf+XpLuw7OYS7yQnMCc3C3eQk9h6/aD556CmR3QPDpqWuv28dhaf3hle9W6w93a2fjXVYBt0m1UvBQi6lvAjgMz6shQQMO2GwinsufnUvNwq7c48mkmnWqhLxhtrZONIXw4+unsTI+D0AwL9eu4Ul8+bYiqV+7PjYR3jn/IfoaKrD9seN4h815cjLNdaEZppPAGrU3N3khHmdSuz1IiZrZ0V1TjvfuJ7e6dQDR/VOV+u06/fil0uGlBemHxJXcg1wWsU9n8d5O/Gya3er0K1VAGYxjhLzkfF7WFo/B2MfTWA0cQ+37044tpNVuduJ5CRODRppis2NNRiKJ9Is50RyAonkpNlnZffhASSSE3huU0eai2TFkvtTFj0wJzT152YtoNLf14uFnJ564mNJrGlrtHUh6YKtP/Eocv2d0CVT+VDIiSu5/hFbxd3Oare7ObgNqLDre65enc6hBP/h5ga8+k4/dqxvxz99cAM/unobn1lyv9mP3Io696rmelPAh+IJLGuoQXwsaVrlAFIFPRJ3730MAOarCuQe6YuhvWkutq9tQe+NOxmtAew+C2taptNTz97jFwEYNzGnnjVOri03F4vd78Zr5hIpHxRy4kqhflWnPubWm4NuxVrTEu36nlvb3epWrN4HHABee3YV9hyNYu/xQTz/1EOuNySVu/37T30S99fMxovf/zGSE5M4PTSKvccvorFOXY+RAXPm0iiWq6pLOXXNKpA7Z/YM1IRmoXtgEPvPXoPqCWMNcKrPQg/4qq6Odk891r42hf5OFGo9xy7EEGmeCjj7MciEFA8KOXGlGOloTlWfAFLibVjFytq0BkHt2t1OBfgmbAc46MfX/cNWl00iOWnmbgNA98Awdm5ox8+3NkLvAqkPrJApAZ8TmmGeTwVy9YlId+9NYu8xw5LW3TR2n4nRF73fsX1wPk3H3NA/B3UzPRGNZ5yfAdPKREhZ+tqcSCQie3p6Sn5e4h/FesR26muebciF8m23L5iLvccvOlreyqJUP1ffr+sIm5OJDAS2dC1ybeLlpdmXfj41tKK5sQbf+a2HUV8Tcswn17NevKZh5vM7mWo3MJk2B9XtsyblQwhxxq4pIS1ykhfFCoDp7WmtOdBuQy5UwFMNRHayGK3+XpXCt6VrsemqUdd2994k+m/ewcbOpjTrXe+dYnXjOJ1vY2cTvnf6CpbWz8FQ3BhuUROaZdsNUWWkqJuL3fUCzoM2gMxsGvUZOuWW671fcmlfTCoDCjnJi0Iesd0tR2F5TW9iNccmC8P4+YQ50Dmb9azQU/haw3WmGKvjH7sQS1VvGk2rrAU76hheKiUNH/3FtOvUs1/03jG6+Kt4gF3vGbtGZXo2jZdpRFa3FQkmFHKSF/n4zvWugUq0rMfQqy/1vG0luHbn1At4AJg+8NFE0mwDe6j3ZlowFUivlrQrZNrY2YQXv/8B2hfMxch4Mi0b5auv/wtWLqvPOrhZv4EkkhO4e+/jVLfG5lQ3xFlpI+b0m1R9TSjN/aOv3c6/rh9Przo1/PTXbYuHWJJfHVDIScmYeozPFC2F3urVKSAIpPfqdupJMiV8vaYFrWe7qAwRvSIUmBLl1nAdVi5rMMW6sS6EzoVzzda1K5e5u3H0awaM/uhK2BXWDJSG2lBaZoiysnduaMeWrkXm2tVnZZ/RMpGW8bLnaNSxeIhUBxRyUjK8PsbbWZrWXiJTo9QMt4e1WdbGzibEx5NITnyMFzZ3mimKKointlWip3K8raKsqjF7LsXx/uXbWFo/B7/56AOor73PtKr1SkkAaX5pJcLKJ69yxNUsVCNnUWD34QHzxqU/KehPG63hOtRHQhmFUbq7xc4qdyrlJ9UDhZyUDK+P8dm2ezNVgt8ark31/87cd8/RKPYeu2gKoNrPLmCqCntUgE93s6hqTNWX/MroXVwZvYs//qWfMwVcZXwYlZzz0qYUqcCtygxRI+XOXBo1uzxah0zolapOlbLqaUO121VDMvRtrYFTUr1QyEnZyDeF0Ytl75Srrizhqfczg6u6O2Tb6mZ8/+w1DMUTWDzvZ7CkvgbtC+rSgok7N7ThsdbGVE/0ujRRTiQncObSrVQ+eltaGuKatvlp3Q/VZ6Jb0Nb+KA83N2BdRxg71rfj0eWNiI8lUzeEqTRi/Ua4NTLViGvfyUHbHHsSfCjkpGzkm8LoxbK328aaVqfa1apsF8DwvR+7EMP2x5ebro3H2+djKH4Zv/hzC9FYd19q4PJ9lpuFwIloHHO0SUIGAt0Dw1jTNt8cZLHv5BB2bmhzHB9nF7hVn5VKSQSMaUQqiOk0JENvxJVITuadMuq1hzopDxRyUjbKXSWoXC2t4Vps6VqEhtoQXjrQixPROEKzZmD/2evYfbgf2x9fnioUmuoprkReCZhdf3KD9MHOuw5ewO7D/di5od30r+87OahZ7e0ZU47UU4TKd1eVr+o1WxAzvYmW3Rq9fVZ2NwE21JYlW6sAABUUSURBVKoMKOSkbJQy9c2pGZRqefvSASNoqs/c/N7py8bOUmYEEHf87fs4EY0jPvaRaZ3btYi19lyfcoEYr3r+uLqhKJ/+lHtmquqyviaE9gV1uDcp8SufXYzLIwk83NyQdp1OVZmFfN5ON91Cb8a06P2BQk7KSin+kK3dE9VAiYbaEPb+ZsRMYwSMlMPXnl2FkfEkem/8FADSipCUpaxmfxqj3AYBTDW9susBrtCnGRkYvvnmxhpEY+Np3QzVvtHYGM5dvYWHmxvSOiPeuH0X0dg4Xn2nH689uwpA9grYfHG6CRR6M6ZF7w8UclJWCulX7lX8Vb64GjShBkqMjCex/+w1rFgyD/U1mROIugfieKSlAWcujWBL16K0AKJyd+j9WNTP9FfrevefvZ5qJWBMCzJ88xI3bv0bpESGda385kq8VWfEzkX349+SE5hfd58ZfFUplqrwKJGcSGu7W4mU271WLVDISVnJ5w85V/HXMzfWtM03BW7fySFH61WtR6X3KdcLkBk01fusZGvba3WtqLzvN85cBYAM61ovoNrY2YQVS6Za4e49dhnrOsLYe3zQdO+82XMF21a3mPvWZARfKwtWlvoDhZyUlXz+kHMV//TMjQmz+EaJ6Zq2RmyNLE2rFm0N15nFOd/4h3Op9rZjpv/ailtWhyoK0kvl9SyTrZGluHbrLo72xbBjfbv5vioM0v3myle/pm2+Y7MvdUy970o2qzyfrBT6tysHCjkJHPmIv9P4M/1ra7UoYPjMa0KzcKQvlmaVW3F6SlBW/5q2+QBgBi0BYNfBC1DtehfPm4NLIwm8NzSCzy2rB5A5cBkwXC8NtbPRPTCMte3zbZt9KWHVA7TZPq98slLo364cKOQk8ORiGVpvAvrkG6dqUT2TxQllPT/c3GAZbGxY/d0Dw+hc+Im01EJVAXru6m3sWN+OdR1hbOxsQjQ2hhe//wGWh+uwfW1LmlX96jv9GBm/h9ZwbdYGWHqpfzbyyUrJxy1GK744UMhJ4PHDMnSrFlWZLE6MjCfT5ozqHQq3rW5Bz9BoKstlqvpS9UA5c2kkbb8VS67hwLkbiMbG0T0Qz+hJrkbR/fmvfiarEO4/e808pt6wS12jl8EYbk8/+TwZ5fq7ovB7g0JOAo8fmQ+FBN2m5mzOR/uCOqxYMi+t4derv/45vNlzBfGxJF5+6zxeP30Z3/mth81q0pXL6k1ftxo1t6yhBr/wqZ/Fl1ctTeu5cvRCDNHYOI5eiJkuGCfUMOh3L8Yz+pPrqZLW9rjFJNffFd033qCQk8CTiwgXw8LT+6rsPjyA7WuX42tvnDWDpmp9hk8cGIon8NKBXjy6vDFjsIVefanWp3zgI+NJvHvRyF9X4ux2DXNmGzNETw2OYt/JQWxbbbhpVNaO3ilSv1kUk1xvmExP9MaM7JsQEnxUxeW+k4N4+a3zaT3Bvexn9GUxvt91sA+7Dl4w31PitG11C55/6iH03rhtBkf1/betbsb2x5fjsdZG7FjfnpbNop8DAEYTyYz39p0cxKlBYyj0qcGRrNewbXWLGWQFhCn6uw8PYN/JIXPdKkOnElwX1s9brbES1lbJ0CIn0wK7oRZerHM7d4RT7rk+WUilMVr3/8YXP5lWafr8Uw+Z04uAKat+asyc7lIwqkA/98A81IRmZg1iNtSG8K1nPpvWH92ax+7Vb14qXzVdKflBISfTArtgpj6Jx0k07Co242MfoffGnTQh1YWuviaER5c3or4mZOsaUD71dR3htPeN0XI/BgB8avH9ePzB9J+rxlxK7PVyfidUtadam7X3i1e/uZ3A5iruXnLS1WdKV0puUMjJtMBpLJr+qrCb36kfp7HuPnQPDKYJqdugY7fzKmsXAPafvY7ugWGs6wjjyw8vxf6z17Dv5JDZ/ErvvXLm0ijiY0lPxT7WQiG3z8DJb25XYJSruDMnvXhQyMm0RRdpXYCyiYrdDUDP2a6vCblWVerFSbsOXkDP0AhOROPm4AnlmlEWsdWFc6j3JroH4ugeiKOxzjl4aGflZmvqpX9vbT1gLTByetowWu9OZgRt/c5JJ1NQyAmBvdXqJCp21r1ehfmVJ1o9VVXqRUHrOsLmoIk9R6M40hdDc2MNvvCpn7W1ovWBzU6ta3U3ier4qFvW0diY7dxPO/RWA3pqpdPTRiI5kXEz9DsnnUxBIScEme6OQvu/eKmq1AVZnxakN/lqTLlUdIymXQ+a3+85GjUDsOeu3koTbXUc1fFRt6zPXb1ltva15pFbXST6EGgv/vAtXYstfdhJMaGQE4LMoGCuDaKs4q9b6GryvXVfqyDr76smX16EcGNnE45diGHi448zZnPaHcfqB7fLI1fW/LELMUSaG7Cla1HavgprozG7p4B8KEdFZ5CrSJlHTkgKJUJ2+dluP7Nja2Rpxsg2r/sqEskJ7Ds5hGhsLCOnXOdQ702ciMbx6PL5qYZcwvV86qal3Cl2eeRbI0uxriOME9E4dh/ux6Hem7b53MqX/9KBXgDGzaE1XGs+BeRLvp9ZIZTjnH7hi0UuhHgSwG4AMwH8Lynln/pxXEIUTvnOxajQ9CMYp1vo+QTy9Hx15QIBsgdf1WejAo1OFrJTQFf/XF95usv0vTut3dpQ7FDvTURj4xmplblSjuBnoAOuUsqC/sEQ7yiA5QBCAH4EoNNtn5UrV0pCcuHbPxyQy75+QH77hwO231cb8bGP5DffPi+/+XafHPjwjvz2DwdkfOwj1+3VNtavf+u7p+Syrx+Qv/GX78pvvn0+Yxsdu8/Vadts68jlWnPdZ7oCoEfaaKofFvkqAANSyosAIIT4OwBfAtDrw7EJAeCc7xw068nrk4SaQuR1e93CVsOaE8lJPLfpQbywuROXR3rQPTCM7oFhc2qQ19RKL9a7WpueWqna+apj6Nvlku5JsuOHkC8GoDuVrgJ4xLqREGI7gO0A8MADD/hwWjKdcMt3DhK5iJZeyp9te12ADXcIcDc5gT1Ho2ZHxebGGnyhs8n15ue1cMpubU7iDCDjmnNJ9yTZ8UPIhc17MuMNKfcC2AsAkUgk4+eETAf0jJH0ARSZWEv53axzXYBVGX4iOYmX3zqPR1oa0NxYg6F4Ao119xWckaOLuO4Lzxw1Z3Ra3NK1OO3a9a+t6Z5qlqo1JZO444eQXwWg30qXALjuw3EJqTiyiZ5XUcy1z4sqFHLax66twMh4Mi1Q6hSAnKrGnDDTFt2uLZGcNEVcD6Ba11sTmuk4ANrpiUovkvIyODrIKYN+4oeQvwegXQjRAuAagF8D8Os+HJeQiiOba8T6cyeh0UXPaRur2Lm5IOzWpfLRlYW7pWux61qUBZ/t2lQrgULW64S1ajUb9K+nsIuA5voPwBcBXICRvfIH2bZn1goJKtkyLKw/V1kg33y7z3G/QjNFnLZ3WotTpk8u11YpmSaVso5SgSJmrUBK+QMAP/DjWIRUMtmCrE5WqbX3iG6F55Ipksu6rMfI1jYgl2vz4hoqBUENevsNS/QJ8Rknf7VbH3Bryl6ubgk794z1GKptwIol1zM6E+aK2/roty49FHJCfCQaG8P2v+5BNDYOwLnznxcrPFtAU8fJR253zvhYMi3PPB/cLGH6rUsPhZwQH3npQC+isXG0hmt9ydfWcRNIp33tng52HexL/VSmbbfv5CDu3vsYc2bPyEj9y8XKZl546aGQE+Ijeu+RXNwK+hAIJ8F0E0i3dD6r+FvHvQHGYGfV2wXITP3Lxcqm37r0UMgJ8ZHWcB1ee3aVp23tKiHteoMr/OiT7nwco67vkZYGPLq8IeNmQSu7sqGQE1IC7FwTdmXqdr3BC8GuKtPO4leDnb1Wd2Y7nhVr33LiL+xHTkgJsOt1rfcst/YIL1a2h1PPbSXUqg2utf+53Xtux7Ni7VtO/IUWOSElwIuLoxTZHk4NsPQpQnaNupwqVvXBzm5Y+5YDlW+lBymNkkJOSAnw4t8uhR/arXDo3YtxtDfNxZG+GNa0NWJjZ1Mqw0VkjHrL9aZTXxPCo8sbUV8zJYjKSgd6PccVSkmQ0igp5ISUGK+9VUrF1sjUkObkxMcAgJXL6nGo96aZyVITmpm2Ni/DpXWs8YA3e65gx/p2AEiz0isJvYPjyHiyoq1yCjkhJabSLD192PPGzibTTw8YrQXsRr3pw6Vbn8juFtGfNtT1r+sI44XNnTjUawyorjShzNbBsZKgkBNSYvx0oVj929l8ul6eBnRhfm5Thy/XYJ1hOpVm2etpcEYu1+InQUm7pJATUmL8dKFkm8Tjtn2hwpnvNehPAA83NwCAZxeNTrZryVfo/brOUkIhJyTA2FmMbtZjLhOKrEwNoDD6haty/i1diz2lTdoJ5J6j0ZxcNHbX4tTTPd+bVqW5vrxAISekiBT78d9qMXotn1dtaBPJSddCILt2u4nkRFo5/7mrtz25R6w3ArtsmFzI1lY3X7dIUNwpOhRyQopIpVp3dn3SVSDSzapVOeQA0ixyL9WoKgukZ2gEJ6JxAJnZMIVeT/ZWBNkJijtFh0JOSBGpVOvOrk+6XYqgXcFPQ20oIwjqxS2iskBORON4rLURkebMni5WvD7RBFF8/YRCTkgRqQSBcRNDazaJelXdEBPJCcfMFSfcKjb1c3hxNZXyiSZIlZxWKOSEVDlexDBTxETqJ8Lh587HON4/jO6BYdhVbOZ6YyvWE42fwdFKgEJOSJXjRQytIqZ3Q7T7udsxtq9tweyZwpeKTT+faOzaBgOFB0crAQo5IRWG34/4+fR58TKazu0Yleia0MXbrsVAoTeNcrpm2MaWkArDa2tYP9Hb2Obzc6/beMWpbW4hx9DbBqsWA4d6bxa8VkU5fm8KWuSEVBhBfsT3Cz/81dZjOAV2/aKcvzchpcy+lc9EIhHZ09NT8vMSQoKBH26KIGehOCGEOCOljFjfp2uFEFJx+OGm8cvV44ebp9hQyAmZxgRBpMpNOX3fXqGPnJBpTJBzp0tFEGIWFHJCpjFBECkrXn3ffvnIK6E6Nxt0rRAyjfEzZbBYWN0/+04O4eW3zmPfySHX/YLgEvELCjkhJG+Kke9tJVOQpeXVHj1vvNpjAXStEELyphj53las7p9tq1vMjo1uZOtXXk1QyAkheeOHj72Q+Z9+nKMa8s3pWiGE5E0p8r398HW7nSPb8YPglinIIhdC/CGA/wAglnrrG1LKHxS6KEIIUdhZ035a0dmeCIKQoumHa2WXlPJ/+HAcQgjJwM6V4qe4ZnPVBCFFkz5yQkjgCIK4lhI/fOS/K4Q4J4T4rhCi3mkjIcR2IUSPEKInFos5bUYIIVkpZv671ScehHz0rEIuhDgkhPjA5t+XAPwFgFYAXQBuAHjF6ThSyr1SyoiUMhIOh327AEKIv1R6cK/Y67MKt56PXqlkda1IKTd6OZAQ4i8BHCh4RYSQslLpwb1iry/btKRKpNCslYVSyhupb38ZwAeFL4kQUk7K4X/OJQul2OsLgnBbKTTY+WdCiC4YtbJDAL5S8IoIIWWlHEKWi5UdRKEtNgUJuZTy3/m1EELI9IVZKIXByk5CSMmxBiyL3YWx0gO4hUIhJ4SUnFKn9AUhhbAQWBBECCk5pXalVLvrRkjp3tO3GEQiEdnT01Py8xJCSJARQpyRUkas79O1QgghAYdCTgghAYdCTgghWaj0rBcKOSGEZKHSs16YtUIIIVnwK+ulWGPlaJETQkgW/CpYKpZlT4ucEDKtKOew5WLls9MiJ4RMK8rp7y5WKwJa5ISQaUU1VnlSyAkh04pqbINL1wohhBSZYuehU8gJIaTIFNsvT9cKIYQUmWL75SnkhBBSZIrtl6drhRBCAg6FnBBCAg6FnBBCAg6FnBBCAg6FnBBCAg6FnBBCAg6FnBBCAo6QUpb+pELEAFzy6XDzAQz7dKxyUi3XAVTPtVTLdQDVcy3Vch1AfteyTEoZtr5ZFiH3EyFEj5QyUu51FEq1XAdQPddSLdcBVM+1VMt1AP5eC10rhBAScCjkhBAScKpByPeWewE+US3XAVTPtVTLdQDVcy3Vch2Aj9cSeB85IYRMd6rBIieEkGkNhZwQQgJO4IVcCPFfhRDnhBBnhRBvCyEWlXtN+SKE+HMhxPnU9fyjEGJeudeUD0KIrUKIHwshPhZCBDJVTAjxpBCiTwgxIIT4/XKvJ1+EEN8VQnwohPig3GspBCHEUiHEESHET1L/t3aWe035IIT4GSHEaSHEj1LX8Ue+HDfoPnIhxCeklD9Nff1VAJ1Syt8p87LyQgjxBQDvSCknhBD/HQCklF8v87JyRgjxSQAfA9gD4D9LKXvKvKScEELMBHABwCYAVwG8B+AZKWVvWReWB0KIxwGMAfhrKeWny72efBFCLASwUEr5vhBiLoAzAH4paL8TIYQAUCulHBNCzAbQDWCnlPLdQo4beItciXiKWgCBvTNJKd+WUk6kvn0XwJJyridfpJQ/kVL2lXsdBbAKwICU8qKUMgng7wB8qcxrygsp5TEAI+VeR6FIKW9IKd9PfX0HwE8ALC7vqnJHGoylvp2d+lewZgVeyAFACPEnQogrAH4DwH8p93p84t8DeKvci5imLAagT8m9igCKRrUihGgG8FkAp8q7kvwQQswUQpwF8CGAg1LKgq8jEEIuhDgkhPjA5t+XAEBK+QdSyqUA/gbA75Z3te5ku5bUNn8AYALG9VQkXq4jwAib9wL7pFdNCCHqAPw9gP9keRoPDFLKSSllF4wn7lVCiIJdXoEYviyl3Ohx078F8H8AvFjE5RREtmsRQmwDsBnABlnBAYwcfidB5CoAfdz5EgDXy7QWkiLlU/57AH8jpfyHcq+nUKSUt4QQPwTwJICCgtGBsMjdEEK0a99uAXC+XGspFCHEkwC+DmCLlDJR7vVMY94D0C6EaBFChAD8GoD9ZV7TtCYVJPwOgJ9IKb9Z7vXkixAirLLRhBBzAGyED5pVDVkrfw+gA0aWxCUAvyOlvFbeVeWHEGIAwH0A4qm33g1iBo4Q4pcBvAogDOAWgLNSyl8o76pyQwjxRQD/E8BMAN+VUv5JmZeUF0KI1wF8HkbL1JsAXpRSfqesi8oDIcQaAMcB/CuMv3UA+IaU8gflW1XuCCFWANgH4//VDABvSCn/uODjBl3ICSFkuhN41wohhEx3KOSEEBJwKOSEEBJwKOSEEBJwKOSEEBJwKOSEEBJwKOSEEBJw/j8+iDVizGc3awAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# d2l.set_figsize()\n",
    "plt.scatter(features[:, 1].detach().numpy(), labels.detach().numpy(), 1);## detach会将requires_grad 属性设置为False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.5107, -0.7532],\n",
      "        [-0.4976, -0.3203],\n",
      "        [-0.2054,  1.2085],\n",
      "        [-0.1446, -0.3844],\n",
      "        [-1.8186,  1.0048],\n",
      "        [-0.0571, -1.3818],\n",
      "        [ 0.8256,  0.0209],\n",
      "        [-1.6343,  1.1719],\n",
      "        [-1.7234,  0.3506],\n",
      "        [ 1.4267, -0.2006]]) \n",
      " tensor([[ 9.7725],\n",
      "        [ 4.2970],\n",
      "        [-0.3211],\n",
      "        [ 5.2255],\n",
      "        [-2.8517],\n",
      "        [ 8.7964],\n",
      "        [ 5.7819],\n",
      "        [-3.0513],\n",
      "        [-0.4562],\n",
      "        [ 7.7238]])\n"
     ]
    }
   ],
   "source": [
    "def data_iter(batch_size, features, labels):\n",
    "    num_examples = len(features)\n",
    "    indices = list(range(num_examples))\n",
    "    # 这些样本是随机读取的，没有特定的顺序\n",
    "    random.shuffle(indices)\n",
    "    for i in range(0, num_examples, batch_size):\n",
    "        batch_indices = torch.tensor(\n",
    "        indices[i: min(i + batch_size, num_examples)])\n",
    "        yield features[batch_indices], labels[batch_indices]\n",
    "batch_size = 10\n",
    "for X, y in data_iter(batch_size, features, labels):\n",
    "    print(X, '\\n', y)\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义初始化参数模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)\n",
    "b = torch.zeros(1, requires_grad=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linreg(X, w, b): #@save\n",
    "    \"\"\"线性回归模型\"\"\"\n",
    "    return torch.matmul(X, w) + b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def squared_loss(y_hat, y): #@save\n",
    "    \"\"\"均⽅损失\"\"\"\n",
    "    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义优化算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sgd(params, lr, batch_size): #@save\n",
    "    \"\"\"⼩批量随机梯度下降\"\"\"\n",
    "    with torch.no_grad():\n",
    "        for param in params:\n",
    "            param -= lr * param.grad / batch_size\n",
    "            param.grad.zero_()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练过程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.033989\n",
      "epoch 2, loss 0.000117\n",
      "epoch 3, loss 0.000049\n",
      "w的估计误差: tensor([ 0.0005, -0.0008], grad_fn=<SubBackward0>)\n",
      "b的估计误差: tensor([8.9645e-05], grad_fn=<RsubBackward1>)\n"
     ]
    }
   ],
   "source": [
    "lr = 0.03\n",
    "num_epochs = 3\n",
    "net = linreg\n",
    "loss = squared_loss\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter(batch_size, features, labels):\n",
    "        l = loss(net(X, w, b), y) # X和y的⼩批量损失\n",
    "        # 因为l形状是(batch_size,1)，⽽不是⼀个标量。l中的所有元素被加到⼀起，\n",
    "        # 并以此计算关于[w,b]的梯度\n",
    "        l.sum().backward()\n",
    "        sgd([w, b], lr, batch_size) # 使⽤参数的梯度更新参数\n",
    "    with torch.no_grad():\n",
    "        train_l = loss(net(features, w, b), labels)\n",
    "        print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')\n",
    "print(f'w的估计误差: {true_w - w.reshape(true_w.shape)}')\n",
    "print(f'b的估计误差: {true_b - b}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "线性回归使用深度学习框架的简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "true_w = torch.tensor([2, -3.4])\n",
    "true_b = 4.2\n",
    "features, labels = synthetic_data(true_w, true_b, 1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "读取数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[ 0.0098,  0.8733],\n",
       "         [ 0.7275,  2.1755],\n",
       "         [ 1.0279, -0.5770],\n",
       "         [-0.4872, -0.7290],\n",
       "         [-1.5895,  2.0745],\n",
       "         [-0.5214, -1.2732],\n",
       "         [-0.0282,  0.9966],\n",
       "         [-1.0470,  0.0140],\n",
       "         [-0.9552, -0.8857],\n",
       "         [-0.4227,  0.2175]]), tensor([[ 1.2702],\n",
       "         [-1.7391],\n",
       "         [ 8.2260],\n",
       "         [ 5.6991],\n",
       "         [-6.0312],\n",
       "         [ 7.4732],\n",
       "         [ 0.7580],\n",
       "         [ 2.0589],\n",
       "         [ 5.3012],\n",
       "         [ 2.6118]])]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def load_array(data_arrays, batch_size, is_train=True): #@save\n",
    "    \"\"\"构造⼀个PyTorch数据迭代器\"\"\"\n",
    "    dataset = Data.TensorDataset(*data_arrays)\n",
    "    return Data.DataLoader(dataset, batch_size, shuffle=is_train)\n",
    "batch_size = 10\n",
    "data_iter = load_array((features, labels), batch_size)\n",
    "next(iter(data_iter))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用框架预定义好的层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nn是神经⽹络的缩写\n",
    "from torch import nn\n",
    "net = nn.Sequential(nn.Linear(2, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "初始化模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0164, 0.0010]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "tensor([0.])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data.normal_(0, 0.01)\n",
    "net[0].bias.data.fill_(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = nn.MSELoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义优化算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = torch.optim.SGD(net.parameters(), lr=0.03)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 0.000274\n",
      "epoch 2, loss 0.000103\n",
      "epoch 3, loss 0.000102\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 3\n",
    "for epoch in range(num_epochs):\n",
    "    for X, y in data_iter:\n",
    "        l = loss(net(X) ,y)\n",
    "        trainer.zero_grad()\n",
    "        l.backward()\n",
    "        trainer.step()\n",
    "    l = loss(net(features), labels)\n",
    "    print(f'epoch {epoch + 1}, loss {l:f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
